# -*- coding: utf-8 -*-
"""RNN-IMDB.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1xHNofxQyDRb1e8vSEKs8bJOzvHQQGdJV
"""

# Install pinned versions of torch / torchtext / torchdata / portalocker.
# NOTE: the leading "!" is IPython/Colab shell syntax. This line is not valid
# in a plain .py file; remove it (and install the packages separately) before
# running this script outside a notebook.
!pip install torch==2.0.1 torchtext==0.15.2 torchdata==0.6.1 portalocker==2.7.0
from torchtext.datasets import IMDB
# Materialize the whole training split and print the first (label, text)
# sample. The printed label shows how this torchtext version encodes
# sentiment, which is what the label-mapping code further down must match.
full_dataset = list(IMDB(split='train'))
print(full_dataset[0])
import torch
from torch import nn
import torch.optim as optim
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import random

# 1. Select the compute device: CUDA when a GPU is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)
# 2. Tokenizer and vocabulary
tokenizer = get_tokenizer('basic_english')


def yield_tokens(data_iter):
    """Lazily tokenize the text of every (label, text) pair in ``data_iter``.

    Labels are discarded; each yielded item is ``tokenizer(text)`` using the
    module-level ``tokenizer``.
    """
    yield from (tokenizer(review) for _label, review in data_iter)


# Download the data; IMDB(split=(...)) returns the train and test splits.
train_iter, test_iter = IMDB(split=('train', 'test'))
# Build the vocabulary from the training reviews only. The specials are
# placed first (build_vocab_from_iterator's default), so "<pad>" is index 0
# and "<unk>" is index 1.
vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    specials=["<pad>", "<unk>"],
)
# Tokens that are not in the vocabulary are looked up as "<unk>".
vocab.set_default_index(vocab["<unk>"])

# 3. Data processing
def text_pipeline(x):
    """Map a raw review string to a list of vocabulary indices."""
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Map an IMDB label to a binary target: 1 for positive, 0 for negative.

    Two label encodings are accepted:
      * the strings 'pos' / 'neg' (older torchtext releases), and
      * the integers 2 (positive) / 1 (negative), which is what the
        torchtext 0.15 IMDB dataset pinned at the top of this notebook yields
        (check the sample printed there).
    Comparing only against 'pos' would map every integer label to 0 and the
    model would be trained on a constant target.
    """
    return 1 if x in ('pos', 2) else 0

def collate_batch(batch):
    """Collate (label, text) samples into padded model inputs on ``device``.

    Args:
        batch: iterable of (label, text) pairs, where text is the raw review.

    Returns:
        (text, labels):
          text   -- int64 tensor of token indices, shape (max_len, batch_size).
                    Shorter reviews are padded at the end with vocab['<pad>'].
          labels -- float32 tensor of 0/1 targets, shape (batch_size,).
    """
    labels, texts = [], []
    for label, text in batch:
        # Keep labels as plain ints; they are turned into one tensor below
        # instead of being wrapped one by one in 0-d tensors.
        labels.append(label_pipeline(label))
        texts.append(torch.tensor(text_pipeline(text), dtype=torch.int64))
    # pad_sequence defaults to batch_first=False -> (max_len, batch_size),
    # which matches nn.RNN's default (seq, batch, feature) input layout.
    padded = pad_sequence(texts, padding_value=vocab['<pad>'])
    return padded.to(device), torch.tensor(labels, dtype=torch.float32, device=device)

# 4. Build the RNN model
class RNNModel(nn.Module):
    """Single-layer ``nn.RNN`` binary classifier over token-index sequences.

    Tokens are embedded and run through ``nn.RNN``. The last hidden state of
    each sequence is mapped to one probability by a linear layer and a sigmoid.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding size.
        hidden_dim: RNN hidden-state size.
        padding_idx: index of the padding token. When ``None`` (the default)
            the module-level ``vocab['<pad>']`` is used, as before.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, padding_idx=None):
        super().__init__()
        if padding_idx is None:
            padding_idx = vocab['<pad>']
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.rnn = nn.RNN(embed_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text):
        """Return positive-class probabilities of shape (batch, 1).

        Args:
            text: int64 tensor of shape (seq_len, batch), padded at the END
                with the padding index (the layout ``pad_sequence`` produces
                with ``batch_first=False``).
        """
        embedded = self.embedding(text)
        # True length of each sequence = number of non-padding tokens. This
        # assumes padding appears only at the end of a sequence. Clamp to 1 so
        # an empty review does not break packing.
        lengths = (text != self.embedding.padding_idx).sum(dim=0).clamp(min=1)
        # Packing makes the RNN stop at each sequence's real end. Without it
        # the final hidden state would be computed after running through all
        # trailing padding steps.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), enforce_sorted=False
        )
        _, hidden = self.rnn(packed)
        # hidden: (num_layers, batch, hidden_dim), already restored to the
        # original batch order; take the last layer's state.
        return self.sigmoid(self.fc(hidden[-1]))

# 5. Build the DataLoaders. The dataset iterables are re-created here and
# materialized as lists before being handed to DataLoader.
train_iter, test_iter = IMDB(split=('train', 'test'))
train_dataloader = DataLoader(
    dataset=list(train_iter),
    batch_size=32,
    shuffle=True,  # reshuffle the training samples every epoch
    collate_fn=collate_batch,
)
test_dataloader = DataLoader(
    dataset=list(test_iter),
    batch_size=32,
    shuffle=False,  # keep a fixed order for evaluation
    collate_fn=collate_batch,
)

# 6. Initialize the model, the loss function and the optimizer.
vocab_size = len(vocab)
model = RNNModel(vocab_size, embed_dim=64, hidden_dim=128)
model.to(device)  # nn.Module.to moves the parameters in place
# BCELoss consumes probabilities, which is what the model's sigmoid emits.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 7. Train the model
for epoch in range(10):
    model.train()
    total_loss = 0
    for text, labels in train_dataloader:
        optimizer.zero_grad()
        # The model returns (batch, 1). Drop only the trailing dim so that a
        # final batch of size 1 keeps shape (1,) and still matches `labels`;
        # a bare .squeeze() would collapse it to a 0-d tensor.
        output = model(text).squeeze(1)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # total_loss is the sum over batches of each batch's mean loss.
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

# 8. Evaluate the model on the test split
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for text, labels in test_dataloader:
        # squeeze(1), as in training: shape stays (batch,) for every batch size.
        outputs = model(text).squeeze(1)
        # Outputs are sigmoid probabilities; threshold at 0.5 for the class.
        predicted = (outputs >= 0.5).float()
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"Accuracy: {correct / total:.4f}")